import os

from dotenv import load_dotenv
from langchain import LLMChain
from langchain.chains import StuffDocumentsChain
from langchain.document_transformers import LongContextReorder
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma

# 使用 huggingface 托管的开源LLM来做嵌入，MiniLM-L6-v2是一个较小的LLM
embeddings = HuggingFaceBgeEmbeddings(model_name="all-MiniLM-L6-v2")

text = [
    "篮球是一项伟大的运动。",
    "带我飞往月球是我最喜欢的歌曲之一。",
    "凯尔特人队是我最喜欢的球队。",
    "这是一篇关于波士顿凯尔特人的文件。",
    "我非常喜欢去看电影。",
    "波士顿凯尔特人队以20分的优势赢得了比赛。",
    "这只是一段随机的文字。",
    "《艾尔登之环》是过去15年最好的游戏之一。",
    "L.科内特是凯尔特人队最好的球员之一。",
    "拉里.伯德是一位标志性的NBA球员。"
]

retrieval = Chroma.from_texts(text, embeddings).as_retriever(
    search_kwargs={"k": 10}
)
query = "关于我的喜好都知道什么？"

# 根据相关性返回文本块
docs = retrieval.get_relevant_documents(query)
print(docs)

print("-------------------------------------------------------------")

# 对检索结果进行重新排序，根据论文的方案
# 问题相关性越低的内容块放在中间
# 问题相关性越高的内容块放在头尾
reordering = LongContextReorder()
reo_docs = reordering.transform_documents(docs)

# 头尾共有4个高相关性内容块
print(reo_docs)

print("-------------------------------------------------------------")

# 检测下这种方案的精度效果
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI

load_dotenv("../ai.env")

api_base = os.getenv("OPENAI_API_BASE")
api_key = os.getenv("OPENAI_KEY")

# 设置llm
llm = OpenAI(
    openai_api_key=api_key,
    openai_api_base=api_base,
    model="gpt-3.5-turbo-instruct",
    temperature=0
)

document_prompt = PromptTemplate(
    input_variables=["page_content"], template="{page_content}"
)

stuff_prompt_override = """Given this text extracts:
_________________________________________
{context}
_________________________________________
Please answer the following questions:
{query}
"""

prompt = PromptTemplate(
    template=stuff_prompt_override,
    input_variables=["context", "query"]
)

llm_chain = LLMChain(
    llm=llm,
    prompt=prompt
)

WorkChain = StuffDocumentsChain(
    llm_chain=llm_chain,
    document_prompt=document_prompt,
    document_variable_name="context"
)

# 调用
print(WorkChain.run(
    input_documents=reo_docs,
    query="我最喜欢做什么事情？"
))
